﻿#pragma kernel UpdateSoftbodyMesh

#include "MathUtils.cginc"

// A single weighted influence on a mesh vertex.
// Field order/types are left untouched: elements are read straight out of a StructuredBuffer.
struct Influence
{
    // Index of the influencing transform. The skinning kernel treats values below
    // SkinmapData.bindPoseCount as bones, and every other value as a particle
    // (particle slot = index - bindPoseCount).
    int index;
    // Blend factor: the transformed position/normal/tangent (and the particle color)
    // are multiplied by this before being accumulated.
    float weight;
};

// Per-skinmap offsets into the batched skinning buffers.
// NOTE(review): presumably mirrors a CPU-side struct with the same layout — confirm before reordering fields.
struct SkinmapData
{
    int firstInfluence;         // added to the influence index when reading the 'influences' buffer.
    int firstInfNumber;         // added to the mesh-local vertex index when reading 'influenceOffsets'.
    int firstParticleBindPose;  // added to Influence.index when reading 'bindPoses' (used for both bone and particle influences).
    
    // The following three fields are not read by the UpdateSoftbodyMesh kernel in this file.
    int firstSkinWeight;
    int firstSkinWeightNumber;
    int firstBoneBindPose;

    int bindPoseCount;          // influences with index below this value are bones, the rest are particles.
};

// Per-skeleton range inside the batched bone buffers (bonePos / boneRot / boneScl).
struct SkeletonData
{
    int firstBone;  // added to Influence.index to address the bone buffers.
    int boneCount;  // bone influences whose index is >= boneCount use an identity deformation instead of a bone transform.
};

// Per-mesh ranges inside the batched mesh buffers.
struct MeshData
{
    int firstVertex;   // added to the mesh-local vertex index when reading positions / normals / tangents.
    int vertexCount;   // not read by the UpdateSoftbodyMesh kernel in this file.

    // Triangle range; not read by the UpdateSoftbodyMesh kernel in this file.
    int firstTriangle;
    int triangleCount;
};

// Maps (particleOffsets[renderer] + particle slot) to the particle index used to address
// renderablePositions / renderableOrientations / colors.
StructuredBuffer<int> particleIndices;
StructuredBuffer<int> rendererIndices; // for each vertex/particle, index of its renderer.

// Current particle positions (xyz is used) and orientations, addressed by particle index.
StructuredBuffer<float4> renderablePositions;
StructuredBuffer<quaternion> renderableOrientations;

// restPositions / restOrientations are declared but not referenced by the UpdateSoftbodyMesh kernel in this file.
StructuredBuffer<float4> restPositions;
StructuredBuffer<quaternion> restOrientations;
// Per-particle colors, blended into the output vertex color using the particle influence weights.
StructuredBuffer<float4> colors;

// Declared but not referenced by the UpdateSoftbodyMesh kernel in this file.
StructuredBuffer<int> skinConstraintOffsets;

StructuredBuffer<int> skinmapIndices; // for each renderer, index of its skinmap.
StructuredBuffer<int> meshIndices; // for each renderer, index of its mesh.
StructuredBuffer<int> skeletonIndices; // for each renderer, index of its skeleton.

StructuredBuffer<int> particleOffsets; // for each renderer, index of its first particle in the batch.
StructuredBuffer<int> vertexOffsets;  // for each renderer, index of its first vertex in the batch.

StructuredBuffer<SkinmapData> skinData;
StructuredBuffer<Influence> influences;
// For vertex v of a skinmap, entries [firstInfNumber + v] and [firstInfNumber + v + 1]
// delimit the half-open range of that vertex's influences.
StructuredBuffer<int> influenceOffsets;
StructuredBuffer<float4x4> bindPoses;

// Bone transforms, combined with TRS() when skinning a bone influence.
StructuredBuffer<SkeletonData> skeletonData;
StructuredBuffer<float3> bonePos;
StructuredBuffer<quaternion> boneRot;
StructuredBuffer<float3> boneScl;

// Source (undeformed) mesh attributes, addressed by MeshData.firstVertex + mesh-local vertex index.
StructuredBuffer<MeshData> meshData;
StructuredBuffer<float3> positions;
StructuredBuffer<float3> normals;
StructuredBuffer<float4> tangents;

// Output vertex data. The kernel writes 14 32-bit values per vertex:
// position (3), normal (3), tangent (4), color (4).
RWByteAddressBuffer vertices;

// Variables set from the CPU
uint vertexCount;       // number of vertices to process; threads with an id >= vertexCount exit immediately.
float4x4 world2Solver;  // applied after bone skinning to bring world space results into solver space.

// Skins one batched mesh vertex per thread.
//
// Each vertex is deformed by a weighted list of influences. An influence is either a
// skeleton bone (index < skin.bindPoseCount) or a simulation particle (any other index).
// The weighted position, normal, tangent and particle color are written to the
// 'vertices' byte address buffer using a 14-dword stride per vertex.
//
// Note: normal/tangent are stored as accumulated (not renormalized), and the tangent's w
// component is the source w scaled by the sum of the influence weights.
[numthreads(128, 1, 1)]
void UpdateSoftbodyMesh (uint3 id : SV_DispatchThreadID) 
{
    // output layout, in 32-bit words: position(3) | normal(3) | tangent(4) | color(4).
    const uint NORMAL_OFFSET  = 3;
    const uint TANGENT_OFFSET = 6;
    const uint COLOR_OFFSET   = 10;
    const uint VERTEX_STRIDE  = 14;

    unsigned int i = id.x;

    // threads past the last vertex have nothing to do.
    if (i >= vertexCount) return;

    int rendererIndex = rendererIndices[i];

    // get skin map, mesh and skeleton data:
    SkinmapData skin = skinData[skinmapIndices[rendererIndex]];
    MeshData mesh = meshData[meshIndices[rendererIndex]];
    SkeletonData skel = skeletonData[skeletonIndices[rendererIndex]];
    
    // get index of this vertex in its original mesh:
    int originalVertexIndex = i - vertexOffsets[rendererIndex];

    // get index of the vertex in the mesh batch:
    int batchedVertexIndex = mesh.firstVertex + originalVertexIndex;

    // per-vertex data that does not change between influences: fetch it once
    // instead of once per influence.
    float3 meshPosition = positions[batchedVertexIndex];
    float3 meshNormal   = normals[batchedVertexIndex];
    float4 meshTangent  = tangents[batchedVertexIndex];
    int firstParticle   = particleOffsets[rendererIndex];

    // get the half-open range [influenceStart, influenceEnd) of influences for this vertex:
    int influenceStart = influenceOffsets[skin.firstInfNumber + originalVertexIndex];
    int influenceEnd   = influenceOffsets[skin.firstInfNumber + originalVertexIndex + 1];
    
    float3 position = float3(0,0,0);
    float3 normal = float3(0,0,0);
    float4 tangent = FLOAT4_ZERO;
    float4 color = FLOAT4_ZERO;
    
    for (int k = influenceStart; k < influenceEnd; ++k)
    {
        Influence inf = influences[skin.firstInfluence + k];

        // both kinds of influence address their bind pose the same way.
        float4x4 bind = bindPoses[skin.firstParticleBindPose + inf.index];
        float4x4 trfm;

        if (inf.index < skin.bindPoseCount) // bone influence:
        {
            int boneIndex = skel.firstBone + inf.index;

            // bind poses that have no matching bone in the skeleton are left undeformed.
            float4x4 deform = inf.index < skel.boneCount ? TRS(bonePos[boneIndex], boneRot[boneIndex], boneScl[boneIndex]) : FLOAT4X4_IDENTITY;

            // bone skinning leaves vertices in world space, so convert to solver space afterwards:
            trfm = mul(world2Solver, mul(deform, bind)); 
        }
        else // particle influence
        {
            int p = particleIndices[firstParticle + inf.index - skin.bindPoseCount];

            // rigid transform of the particle: rotate by its orientation, then translate to its position.
            float4x4 deform = mul(m_translate(FLOAT4X4_IDENTITY,renderablePositions[p].xyz), q_toMatrix(renderableOrientations[p]));
            trfm = mul(deform, bind);

            // only particles contribute to the vertex color.
            color += colors[p] * inf.weight;
        }

        // update vertex/normal/tangent (w = 1 for points, w = 0 for directions):
        position += mul(trfm, float4(meshPosition, 1)).xyz * inf.weight;
        normal += mul(trfm, float4(meshNormal, 0)).xyz * inf.weight;
        tangent += float4(mul(trfm, float4(meshTangent.xyz, 0)).xyz, meshTangent.w) * inf.weight;
    }

    // byte addresses are unsigned: compute the word offset as uint, then convert words to bytes (<< 2).
    uint base = i * VERTEX_STRIDE;
    vertices.Store3( base << 2, asuint(position));
    vertices.Store3((base + NORMAL_OFFSET) << 2, asuint(normal));
    vertices.Store4((base + TANGENT_OFFSET) << 2, asuint(tangent));
    vertices.Store4((base + COLOR_OFFSET) << 2, asuint(color));
}